{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoModelForMaskedLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the model with Transformers and Torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = [\n",
    "    \"Hello World\",\n",
    "    \"Built by Nirant Kasliwal\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch Code from the [SPLADERunner](https://github.com/PrithivirajDamodaran/SPLADERunner) library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_token = \"<your_hf_token_here>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output Logits shape:  torch.Size([2, 10, 30522])\n",
      "Output Attention mask shape:  torch.Size([2, 10])\n",
      "Sparse Vector shape:  torch.Size([2, 30522])\n",
      "SPLADE BOW rep for sentence:\tBuilt by Nirant Kasliwal\n",
      "[('##rant', 2.02), ('built', 1.94), ('##wal', 1.79), ('##sl', 1.69), ('build', 1.57), ('ka', 1.4), ('ni', 1.26), ('made', 0.93), ('architect', 0.76), ('was', 0.69), ('who', 0.61), ('his', 0.5), ('wrote', 0.47), ('india', 0.45), ('company', 0.41), ('##i', 0.41), ('he', 0.37), ('manufacturer', 0.36), ('by', 0.35), ('engineer', 0.33), ('architecture', 0.33), ('ko', 0.23), ('him', 0.22), ('invented', 0.19), ('said', 0.14), ('k', 0.11), ('man', 0.11), ('statue', 0.11), ('bomb', 0.1), ('##wa', 0.1), ('builder', 0.09), ('.', 0.07), ('started', 0.06), (',', 0.04), ('ku', 0.03)]\n"
     ]
    }
   ],
   "source": [
    "# Download the model and tokenizer\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"prithivida/Splade_PP_en_v1\", token=hf_token)\n",
    "reverse_voc = {v: k for k, v in tokenizer.vocab.items()}\n",
    "model = AutoModelForMaskedLM.from_pretrained(\"prithivida/Splade_PP_en_v1\", token=hf_token)\n",
    "model.to(device)\n",
    "\n",
    "# Tokenize the input\n",
    "inputs = tokenizer(sentences, return_tensors=\"pt\", padding=True, truncation=True, max_length=512)\n",
    "inputs = {key: val.to(device) for key, val in inputs.items()}\n",
    "input_ids = inputs[\"input_ids\"]\n",
    "attention_mask = inputs[\"attention_mask\"]\n",
    "token_type_ids = inputs[\"token_type_ids\"]\n",
    "\n",
    "# Run model and prepare sparse vector\n",
    "outputs = model(**inputs)\n",
    "logits = outputs.logits\n",
    "print(\"Output Logits shape: \", logits.shape)\n",
    "print(\"Output Attention mask shape: \", attention_mask.shape)\n",
    "relu_log = torch.log(1 + torch.relu(logits))\n",
    "weighted_log = relu_log * attention_mask.unsqueeze(-1)\n",
    "max_val, _ = torch.max(weighted_log, dim=1)\n",
    "vector = max_val.squeeze()\n",
    "print(\"Sparse Vector shape: \", vector.shape)\n",
    "# print(\"Number of Actual Dimensions: \", len(cols))\n",
    "cols = [vec.nonzero().squeeze().cpu().tolist() for vec in vector]\n",
    "weights = [vec[col].cpu().tolist() for vec, col in zip(vector, cols)]\n",
    "\n",
    "idx = 1\n",
    "cols, weights = cols[idx], weights[idx]\n",
    "# Print the BOW representation\n",
    "d = {k: v for k, v in zip(cols, weights)}\n",
    "sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}\n",
    "bow_rep = []\n",
    "for k, v in sorted_d.items():\n",
    "    bow_rep.append((reverse_voc[k], round(v, 2)))\n",
    "print(f\"SPLADE BOW rep for sentence:\\t{sentences[idx]}\\n{bow_rep}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export with output_attentions and logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exporting model to models/nirantk_SPLADE_PP_en_v1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('models/nirantk_SPLADE_PP_en_v1/tokenizer_config.json',\n",
       " 'models/nirantk_SPLADE_PP_en_v1/special_tokens_map.json',\n",
       " 'models/nirantk_SPLADE_PP_en_v1/vocab.txt',\n",
       " 'models/nirantk_SPLADE_PP_en_v1/added_tokens.json',\n",
       " 'models/nirantk_SPLADE_PP_en_v1/tokenizer.json')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_id = \"nirantk/SPLADE_PP_en_v1\"\n",
    "output_dir = f\"models/{model_id.replace('/', '_')}\"\n",
    "model_kwargs = {\"output_attentions\": True, \"return_dict\": True}\n",
    "\n",
    "print(f\"Exporting model to {output_dir}\")\n",
    "tokenizer.save_pretrained(output_dir)\n",
    "# main_export(\n",
    "#     model_id,\n",
    "#     output=output_dir,\n",
    "#     no_post_process=True,\n",
    "#     model_kwargs=model_kwargs,\n",
    "#     token=hf_token,\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the model with ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from optimum.onnxruntime import ORTModelForMaskedLM\n",
    "\n",
    "model = ORTModelForMaskedLM.from_pretrained(\"nirantk/SPLADE_PP_en_v1\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"nirantk/SPLADE_PP_en_v1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer(sentences, return_tensors=\"pt\", padding=True, truncation=True, max_length=512)\n",
    "inputs = {key: val.to(device) for key, val in inputs.items()}\n",
    "input_ids = inputs[\"input_ids\"]\n",
    "attention_mask = inputs[\"attention_mask\"]\n",
    "token_type_ids = inputs[\"token_type_ids\"]\n",
    "\n",
    "onnx_input = {\n",
    "    \"input_ids\": input_ids.cpu().numpy(),\n",
    "    \"attention_mask\": attention_mask.cpu().numpy(),\n",
    "    \"token_type_ids\": token_type_ids.cpu().numpy(),\n",
    "}\n",
    "\n",
    "logits = model(**onnx_input).logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 10, 30522)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output Logits shape:  (2, 10, 30522)\n",
      "Sparse Vector shape:  (2, 30522)\n",
      "SPLADE BOW rep for sentence:\tBuilt by Nirant Kasliwal\n",
      "[('##rant', 2.02), ('built', 1.94), ('##wal', 1.79), ('##sl', 1.69), ('build', 1.57), ('ka', 1.4), ('ni', 1.26), ('made', 0.93), ('architect', 0.76), ('was', 0.69), ('who', 0.61), ('his', 0.5), ('wrote', 0.47), ('india', 0.45), ('company', 0.41), ('##i', 0.41), ('he', 0.37), ('manufacturer', 0.36), ('by', 0.35), ('engineer', 0.33), ('architecture', 0.33), ('ko', 0.23), ('him', 0.22), ('invented', 0.19), ('said', 0.14), ('k', 0.11), ('man', 0.11), ('statue', 0.11), ('bomb', 0.1), ('##wa', 0.1), ('builder', 0.09), ('.', 0.07), ('started', 0.06), (',', 0.04), ('ku', 0.03)]\n"
     ]
    }
   ],
   "source": [
    "print(\"Output Logits shape: \", logits.shape)\n",
    "\n",
    "relu_log = np.log(1 + np.maximum(logits, 0))\n",
    "\n",
    "# Equivalent to relu_log * attention_mask.unsqueeze(-1)\n",
    "# For NumPy, you might need to explicitly expand dimensions if 'attention_mask' is not already 2D\n",
    "weighted_log = relu_log * np.expand_dims(attention_mask, axis=-1)\n",
    "\n",
    "# Equivalent to torch.max(weighted_log, dim=1)\n",
    "# NumPy's max function returns only the max values, not the indices, so we don't need to unpack two values\n",
    "max_val = np.max(weighted_log, axis=1)\n",
    "\n",
    "# Equivalent to max_val.squeeze()\n",
    "# This step may be unnecessary in NumPy if max_val doesn't have unnecessary dimensions\n",
    "vector = np.squeeze(max_val)\n",
    "print(\"Sparse Vector shape: \", vector.shape)\n",
    "\n",
    "# print(vector[0].nonzero())\n",
    "\n",
    "cols = [vec.nonzero()[0].squeeze().tolist() for vec in vector]\n",
    "weights = [vec[col].tolist() for vec, col in zip(vector, cols)]\n",
    "\n",
    "idx = 1\n",
    "cols, weights = cols[idx], weights[idx]\n",
    "# Print the BOW representation\n",
    "d = {k: v for k, v in zip(cols, weights)}\n",
    "sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}\n",
    "bow_rep = []\n",
    "for k, v in sorted_d.items():\n",
    "    bow_rep.append((reverse_voc[k], round(v, 2)))\n",
    "print(f\"SPLADE BOW rep for sentence:\\t{sentences[idx]}\\n{bow_rep}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "35"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1010,\n",
       " 1012,\n",
       " 1047,\n",
       " 2001,\n",
       " 2002,\n",
       " 2010,\n",
       " 2011,\n",
       " 2032,\n",
       " 2040,\n",
       " 2056,\n",
       " 2072,\n",
       " 2081,\n",
       " 2158,\n",
       " 2194,\n",
       " 2318,\n",
       " 2328,\n",
       " 2626,\n",
       " 2634,\n",
       " 3857,\n",
       " 3992,\n",
       " 4213,\n",
       " 4294,\n",
       " 4944,\n",
       " 5968,\n",
       " 6231,\n",
       " 7751,\n",
       " 8826,\n",
       " 9152,\n",
       " 10556,\n",
       " 12508,\n",
       " 12849,\n",
       " 13476,\n",
       " 13970,\n",
       " 14540,\n",
       " 17884]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fst",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
